import torch
from torch.nn import functional as F
from torch import nn
from pytorch_lightning.core.lightning import LightningModule
import pytorch_lightning as pl
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import pickle
# Select the compute device: CUDA when torch reports an available GPU,
# otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import json
import argparse
import sys
import matplotlib.pyplot as plt
import pandas as pd
from IPython.core.display import HTML
import sys
# Rebind sys.path to its first eight entries, dropping anything after them.
# NOTE(review): the cut-off of 8 is purely positional -- presumably chosen to
# match this environment's default entries; confirm before reusing elsewhere.
sys.path = sys.path[:8]
# Bare expression: no effect when run as a script (presumably kept to echo the
# resulting list in an interactive session).
sys.path
['/home/jupyter/nwp-downscale/notebooks', '/opt/conda/envs/ilan/lib/python39.zip', '/opt/conda/envs/ilan/lib/python3.9', '/opt/conda/envs/ilan/lib/python3.9/lib-dynload', '', '/opt/conda/envs/ilan/lib/python3.9/site-packages', '/opt/conda/envs/ilan/lib/python3.9/site-packages/IPython/extensions', '/home/jupyter/.ipython']
# Load the evaluation outputs of the run described by
# eval_leingan_ens10_temperature_input.json: the pickled metrics and the
# saved sample-predictions image.
config_path = '../experiments/eval_leingan_ens10_temperature_input.json'
with open(config_path) as config_file:
    args = json.load(config_file)
# Turn the config entries into attributes of a Namespace. parse_known_args
# returns unrecognised command-line arguments separately instead of failing on
# them (e.g. extra arguments passed by whatever launched the interpreter).
# The config dict is deliberately NOT passed positionally: the first positional
# parameter of ArgumentParser is `prog`, the program name shown in usage text.
parser = argparse.ArgumentParser()
parser.set_defaults(**args)
args, _ = parser.parse_known_args()
# Run directory is <save_dir><run_name><run_number>/ (plain concatenation, so
# save_dir is expected to carry its own trailing separator).
folder = (
    args.save_hparams["save_dir"]
    + args.save_hparams["run_name"]
    + str(args.save_hparams["run_number"])
    + "/"
)
# NOTE: pickle.load can execute arbitrary code -- only open metrics files
# produced by trusted runs.
with open(folder+"eval_metrics.pkl", "rb") as metrics_file:
    leingan_ens10_temp_metrics = pickle.load(metrics_file)
leingan_ens10_temp_preds = plt.imread(folder+'sample_predictions.png', format='png')
# Load the evaluation outputs of the run described by
# eval_leingan_single_forecast_input.json: the pickled metrics and the saved
# sample-predictions image.
config_path = '../experiments/eval_leingan_single_forecast_input.json'
with open(config_path) as config_file:
    args = json.load(config_file)
# Turn the config entries into attributes of a Namespace. parse_known_args
# returns unrecognised command-line arguments separately instead of failing on
# them. The config dict is deliberately NOT passed positionally: the first
# positional parameter of ArgumentParser is `prog`.
parser = argparse.ArgumentParser()
parser.set_defaults(**args)
args, _ = parser.parse_known_args()
# Run directory is <save_dir><run_name><run_number>/ (plain concatenation).
folder = (
    args.save_hparams["save_dir"]
    + args.save_hparams["run_name"]
    + str(args.save_hparams["run_number"])
    + "/"
)
# NOTE: pickle.load can execute arbitrary code -- only open metrics files
# produced by trusted runs.
with open(folder+"eval_metrics.pkl", "rb") as metrics_file:
    leingan_single_forecast_metrics = pickle.load(metrics_file)
# Bound as `leingan_single_forecast_preds` (the name the display code reads);
# the previous spelling `leingan_single_forecas_preds` left that name undefined.
leingan_single_forecast_preds = plt.imread(folder+'sample_predictions.png', format='png')
# Load the evaluation outputs of the run described by
# eval_wgan-gp_single_forecast.json: the pickled metrics and the saved
# sample-predictions image.
config_path = '../experiments/eval_wgan-gp_single_forecast.json'
with open(config_path) as config_file:
    args = json.load(config_file)
# Turn the config entries into attributes of a Namespace. parse_known_args
# returns unrecognised command-line arguments separately instead of failing on
# them. The config dict is deliberately NOT passed positionally: the first
# positional parameter of ArgumentParser is `prog`.
parser = argparse.ArgumentParser()
parser.set_defaults(**args)
args, _ = parser.parse_known_args()
# Run directory is <save_dir><run_name><run_number>/ (plain concatenation).
folder = (
    args.save_hparams["save_dir"]
    + args.save_hparams["run_name"]
    + str(args.save_hparams["run_number"])
    + "/"
)
# NOTE: pickle.load can execute arbitrary code -- only open metrics files
# produced by trusted runs.
with open(folder+"eval_metrics.pkl", "rb") as metrics_file:
    wgan_gp_metrics = pickle.load(metrics_file)
wgan_gp_preds = plt.imread(folder+'sample_predictions.png', format='png')
# Load the evaluation outputs of the run described by
# eval_wgan-gp-smoothed_single_forecast.json: the pickled metrics and the
# saved sample-predictions image.
config_path = '../experiments/eval_wgan-gp-smoothed_single_forecast.json'
with open(config_path) as config_file:
    args = json.load(config_file)
# Turn the config entries into attributes of a Namespace. parse_known_args
# returns unrecognised command-line arguments separately instead of failing on
# them. The config dict is deliberately NOT passed positionally: the first
# positional parameter of ArgumentParser is `prog`.
parser = argparse.ArgumentParser()
parser.set_defaults(**args)
args, _ = parser.parse_known_args()
# Run directory is <save_dir><run_name><run_number>/ (plain concatenation).
folder = (
    args.save_hparams["save_dir"]
    + args.save_hparams["run_name"]
    + str(args.save_hparams["run_number"])
    + "/"
)
# NOTE: pickle.load can execute arbitrary code -- only open metrics files
# produced by trusted runs.
with open(folder+"eval_metrics.pkl", "rb") as metrics_file:
    wgan_gp_smoothed_metrics = pickle.load(metrics_file)
wgan_gp_smoothed_preds = plt.imread(folder+'sample_predictions.png', format='png')
# (display name, metrics dict) for every evaluated model, in the order they
# appear in the table and in the plots that iterate over `models`.
models = [
    ('leingan_single_forecast', leingan_single_forecast_metrics),
    ('leingan_ens10_temp', leingan_ens10_temp_metrics),
    ('wgan_gp', wgan_gp_metrics),
    ('wgan_gp_smoothed', wgan_gp_smoothed_metrics),
]
# One row per model holding its scalar scores. The frame is built in a single
# call from a list of records: DataFrame.append was deprecated in pandas 1.4
# and removed in pandas 2.0, and appending row by row copies the frame on
# every iteration.
scalar_metrics = pd.DataFrame(
    [
        {
            "model": model_name,
            "crps": metrics["crps"],
            "avg_pool_crps": metrics["avg_pool_crps"],
            "max_pool_crps": metrics["max_pool_crps"],
            "rmse": metrics["rmse"],
        }
        for model_name, metrics in models
    ],
    columns=["model", "crps", "avg_pool_crps", "max_pool_crps", "rmse"],
)
# Bare expression so an interactive session echoes the table.
scalar_metrics
| | model | crps | avg_pool_crps | max_pool_crps | rmse |
|---|---|---|---|---|---|
| 0 | leingan_single_forecast | 0.356860 | 0.349406 | 0.595922 | 1.024945 |
| 1 | leingan_ens10_temp | 0.378926 | 0.374730 | 0.664519 | 1.213675 |
| 2 | wgan_gp | 0.361144 | 0.351502 | 0.594240 | 0.993455 |
| 3 | wgan_gp_smoothed | 0.333920 | 0.324778 | 0.565999 | 0.996977 |
# Rank histogram: one curve per model, each normalised by its own total so the
# plotted values sum to 1.
for model_name, metrics in models:
    rank_counts = metrics['rankhist']
    plt.plot(rank_counts / np.sum(rank_counts), label=model_name)
# Dashed reference line at 1/31, drawn once for the whole figure.
# NOTE(review): 1/31 is the height of a flat histogram over 31 ranks --
# presumably 'rankhist' has 31 bins; confirm against the evaluation code.
plt.hlines(y=1/31, xmin=0, xmax=31, color='k', linestyle='--')
plt.xlabel('rank')
plt.ylabel('normalised frequency')
plt.title('Rank Histogram')
plt.legend()
<matplotlib.legend.Legend at 0x7fb6ecb929d0>
# Reliability diagram: observed frequency against forecast probability, one
# curve per model, plus a dashed diagonal as reference.
#
# Seed forecast_probs from the WGAN-GP metrics so that models whose
# 'reliability' entry holds only two values still have x-coordinates to plot
# against. (The name `wgan_gp` used here previously is not defined; the
# metrics dict is `wgan_gp_metrics`.)
# NOTE(review): assumes the WGAN-GP entry is the three-element form, so that
# index 1 is the forecast probabilities -- confirm.
forecast_probs = wgan_gp_metrics['reliability'][1]
for model_name, metrics in models:
    try:
        relative_freq, forecast_probs, samples = metrics['reliability']
    except ValueError:
        # Two-element entry: unpacking into three names raises ValueError.
        # Keep the forecast_probs from the previous iteration (or the seed).
        relative_freq, samples = metrics['reliability']
    plt.plot(forecast_probs, relative_freq, label=model_name)
# Reference diagonal from 0 to 0.95 with as many points as forecast_probs.
diagonal = np.linspace(0, 0.95, len(forecast_probs))
plt.plot(diagonal, diagonal, 'k--')
plt.xlabel('Forecast Probability')
plt.ylabel('Observed Frequency')
plt.title('Reliability Diagram')
plt.legend()
<matplotlib.legend.Legend at 0x7fb6eca57850>
# Open a new 20x45 figure and draw the image held in leingan_ens10_temp_preds
# (read from that run's sample_predictions.png). imshow is the last expression
# so an interactive session echoes the returned image object.
plt.figure(figsize=(20, 45))
plt.imshow(leingan_ens10_temp_preds)
<matplotlib.image.AxesImage at 0x7fbd2663ddc0>
# Open a new 20x45 figure and draw the single-forecast run's sample-predictions
# image.
# NOTE(review): relies on `leingan_single_forecast_preds` being bound by the
# cell that loads the single-forecast run -- verify the spelling matches there.
plt.figure(figsize=(20, 45))
plt.imshow(leingan_single_forecast_preds)
<matplotlib.image.AxesImage at 0x7fbd266405b0>
# Open a new 20x45 figure and draw the image held in wgan_gp_preds (read from
# that run's sample_predictions.png). imshow is the last expression so an
# interactive session echoes the returned image object.
plt.figure(figsize=(20, 45))
plt.imshow(wgan_gp_preds)
<matplotlib.image.AxesImage at 0x7fbd14d6a220>
# Open a new 20x45 figure and draw the image held in wgan_gp_smoothed_preds
# (read from that run's sample_predictions.png). imshow is the last expression
# so an interactive session echoes the returned image object.
plt.figure(figsize=(20, 45))
plt.imshow(wgan_gp_smoothed_preds)
<matplotlib.image.AxesImage at 0x7fbd14d58340>